Explaining Text Classification

from explainer.explainers import feature_attributions_explainer, metrics_explainer
import numpy as np
from sklearn import datasets

# Full 20-newsgroups category list (for reference only; not used below).
all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

# Restrict the classification task to five topics.
selected_categories = ['alt.atheism','comp.graphics','rec.motorcycles','sci.space','talk.politics.misc']

# Fetch the train/test splits (downloads the corpus on first run); labels
# come back as integers 0..4.
X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

# Wrap the raw text lists as numpy arrays so slice/fancy indexing works later.
X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)

classes = np.unique(Y_train)
# Map integer label -> category name; relies on fetch_20newsgroups encoding
# labels in the alphabetical order of `selected_categories` — TODO confirm.
mapping = dict(zip(classes, selected_categories))

# Notebook display expression: split sizes, label ids, and id->name mapping.
len(X_train_text), len(X_test_text), classes, mapping
(2720,
 1810,
 array([0, 1, 2, 3, 4]),
 {0: 'alt.atheism',
  1: 'comp.graphics',
  2: 'rec.motorcycles',
  3: 'sci.space',
  4: 'talk.politics.misc'})
# Peek at the integer-encoded test labels.
print(Y_test)
[2 3 0 ... 3 2 3]

Vectorize Text Data

import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

# Cap the vocabulary at the 50k most frequent terms.
vectorizer = TfidfVectorizer(max_features=50000)

# NOTE(review): fitting the vectorizer on train+test leaks test-set vocabulary
# and IDF statistics into the features; for an honest evaluation it should be
# fit on X_train_text only. Left as-is to keep the notebook outputs reproducible.
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

# Densify the sparse TF-IDF matrices for the Keras dense layers below.
X_train, X_test = X_train.toarray(), X_test.toarray()

# Notebook display expression: feature-matrix shapes.
X_train.shape, X_test.shape
((2720, 50000), (1810, 50000))

Define the Model

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

def create_model():
    """Build a fresh 3-layer MLP: TF-IDF input -> 128 -> 64 -> softmax over classes."""
    mlp = Sequential()
    mlp.add(layers.Input(shape=X_train.shape[1:]))
    mlp.add(layers.Dense(128, activation="relu"))
    mlp.add(layers.Dense(64, activation="relu"))
    mlp.add(layers.Dense(len(classes), activation="softmax"))
    return mlp

# Instantiate the classifier defined above.
model = create_model()

# Print the layer/parameter summary.
model.summary()

Compile and Train Model

# Integer labels -> sparse categorical cross-entropy; Adam with default settings.
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(X_train, Y_train, batch_size=256, epochs=5, validation_data=(X_test, Y_test))

Evaluate Model Performance

from sklearn.metrics import accuracy_score, classification_report

# Class-probability predictions for both splits (test_preds is reused by the
# explainer cells further down, so its name must not change).
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# Collapse probability rows to hard label predictions once, up front.
train_labels = np.argmax(train_preds, axis=1)
test_labels = np.argmax(test_preds, axis=1)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_labels)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_labels)))
print("\nClassification Report : ")
print(classification_report(Y_test, test_labels, target_names=selected_categories))
Hide code cell output
 1/85 [..............................] - ETA: 7s

 7/85 [=>............................] - ETA: 0s

14/85 [===>..........................] - ETA: 0s

21/85 [======>.......................] - ETA: 0s

28/85 [========>.....................] - ETA: 0s

36/85 [===========>..................] - ETA: 0s

43/85 [==============>...............] - ETA: 0s

51/85 [=================>............] - ETA: 0s

59/85 [===================>..........] - ETA: 0s

66/85 [======================>.......] - ETA: 0s

73/85 [========================>.....] - ETA: 0s

80/85 [===========================>..] - ETA: 0s

85/85 [==============================] - 1s 8ms/step
 1/57 [..............................] - ETA: 6s

 7/57 [==>...........................] - ETA: 0s

14/57 [======>.......................] - ETA: 0s

20/57 [=========>....................] - ETA: 0s

27/57 [=============>................] - ETA: 0s

34/57 [================>.............] - ETA: 0s

41/57 [====================>.........] - ETA: 0s

47/57 [=======================>......] - ETA: 0s

57/57 [==============================] - ETA: 0s

57/57 [==============================] - 1s 8ms/step
Train Accuracy : 1.000
Test  Accuracy : 0.942

Classification Report : 
                    precision    recall  f1-score   support

       alt.atheism       0.98      0.93      0.95       319
     comp.graphics       0.88      0.98      0.93       389
   rec.motorcycles       0.97      0.99      0.98       398
         sci.space       0.94      0.92      0.93       394
talk.politics.misc       0.96      0.89      0.92       310

          accuracy                           0.94      1810
         macro avg       0.95      0.94      0.94      1810
      weighted avg       0.94      0.94      0.94      1810
# One-hot-encode classes: row i of the identity matrix is the one-hot vector
# for integer label i.
oh_Y_test = np.eye(len(classes))[Y_test]

# Confusion-matrix explainer from the project toolkit: compares one-hot ground
# truth against predicted probabilities, labeling axes with the category names.
cm = metrics_explainer['confusionmatrix'](oh_Y_test, test_preds, selected_categories)
cm.visualize()
print(cm.report)
                    precision    recall  f1-score   support

       alt.atheism       0.98      0.93      0.95       319
     comp.graphics       0.88      0.98      0.93       389
   rec.motorcycles       0.97      0.99      0.98       398
         sci.space       0.94      0.92      0.93       394
talk.politics.misc       0.96      0.89      0.92       310

          accuracy                           0.94      1810
         macro avg       0.95      0.94      0.94      1810
      weighted avg       0.94      0.94      0.94      1810
../../_images/808008eb61f9b3b24fe567e7c00c78c88c799077f2cb055bf070eec77e2d6477.png
# Per-class precision-recall and ROC curves via the metrics explainer.
plotter = metrics_explainer['plot'](oh_Y_test, test_preds, selected_categories)
plotter.pr_curve()
plotter.roc_curve()
import re

# Pick two test samples (indices 1 and 2) to explain.
X_batch_text = X_test_text[1:3]
X_batch = X_test[1:3]

print("Samples : ")
for text in X_batch_text:
    print(re.split(r"\W+", text))
    print()

# Predicted class probabilities and hard labels for the two samples.
preds_proba = model.predict(X_batch)
preds = preds_proba.argmax(axis=1)
# Fix: use a raw string for the regex — the original "\W+" contains the invalid
# escape sequence \W (SyntaxWarning on Python 3.12+) and was inconsistent with
# the r"\W+" used in the loop above. Matched text is identical.
tokens = re.split(r"\W+", X_batch_text[0].lower())

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[1:3]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))
Samples : 
['From', 'prb', 'access', 'digex', 'net', 'Pat', 'Subject', 'Re', 'Near', 'Miss', 'Asteroids', 'Q', 'Organization', 'Express', 'Access', 'Online', 'Communications', 'Greenbelt', 'MD', 'USA', 'Lines', '4', 'Distribution', 'sci', 'NNTP', 'Posting', 'Host', 'access', 'digex', 'net', 'TRry', 'the', 'SKywatch', 'project', 'in', 'Arizona', 'pat', '']

['From', 'cobb', 'alexia', 'lis', 'uiuc', 'edu', 'Mike', 'Cobb', 'Subject', 'Science', 'and', 'theories', 'Organization', 'University', 'of', 'Illinois', 'at', 'Urbana', 'Lines', '19', 'As', 'per', 'various', 'threads', 'on', 'science', 'and', 'creationism', 'I', 've', 'started', 'dabbling', 'into', 'a', 'book', 'called', 'Christianity', 'and', 'the', 'Nature', 'of', 'Science', 'by', 'JP', 'Moreland', 'A', 'question', 'that', 'I', 'had', 'come', 'from', 'one', 'of', 'his', 'comments', 'He', 'stated', 'that', 'God', 'is', 'not', 'necessarily', 'a', 'religious', 'term', 'but', 'could', 'be', 'used', 'as', 'other', 'scientific', 'terms', 'that', 'give', 'explanation', 'for', 'events', 'or', 'theories', 'without', 'being', 'a', 'proven', 'scientific', 'fact', 'I', 'think', 'I', 'got', 'his', 'point', 'I', 'can', 'quote', 'the', 'section', 'if', 'I', 'm', 'being', 'vague', 'The', 'examples', 'he', 'gave', 'were', 'quarks', 'and', 'continental', 'plates', 'Are', 'there', 'explanations', 'of', 'science', 'or', 'parts', 'of', 'theories', 'that', 'are', 'not', 'measurable', 'in', 'and', 'of', 'themselves', 'or', 'can', 'everything', 'be', 'quantified', 'measured', 'tested', 'etc', 'MAC', 'Michael', 'A', 'Cobb', 'and', 'I', 'won', 't', 'raise', 'taxes', 'on', 'the', 'middle', 'University', 'of', 'Illinois', 'class', 'to', 'pay', 'for', 'my', 'programs', 'Champaign', 'Urbana', 'Bill', 'Clinton', '3rd', 'Debate', 'cobb', 'alexia', 'lis', 'uiuc', 'edu', 'Nobody', 'can', 'explain', 'everything', 'to', 'anybody', 'G', 'K', 'Chesterton', '']


1/1 [==============================] - ETA: 0s

1/1 [==============================] - 0s 33ms/step
Actual    Target Values : ['sci.space', 'alt.atheism']
Predicted Target Values : ['sci.space', 'alt.atheism']
Predicted Probabilities : [0.9117346  0.78534925]

SHAP Partition Explainer

Visualize SHAP Values Correct Predictions

def make_predictions(X_batch_text):
    """Vectorize raw texts with the fitted TF-IDF vectorizer and return the
    model's class-probability predictions for them."""
    dense_features = vectorizer.transform(X_batch_text).toarray()
    return model.predict(dense_features)

# Build a SHAP Partition explainer that tokenizes on non-word characters and
# immediately compute attributions for the two selected samples.
partition_explainer = feature_attributions_explainer.partitionexplainer(make_predictions, r"\W+", selected_categories)(X_batch_text)

Text Plot

# SHAP text plot: per-token attributions toward each output class.
partition_explainer.visualize()


[0]
outputs
alt.atheism
comp.graphics
rec.motorcycles
sci.space
talk.politics.misc


0.50.30.10.70.90.1504660.150466base value0.01075450.0107545falt.atheism(inputs)0.016 Arizona. 0.005 TRry 0.005 SKywatch 0.003 Miss 0.0 the 0.0 Re: -0.015 project -0.012 pat -0.01 digex. -0.009 Pat) -0.009 access. -0.008 prb@ -0.008 sci -0.008 Online -0.007 net ( -0.007 net -0.007 Access -0.007 Express -0.007 digex. -0.006 Communications, -0.006 access. -0.005 Asteroids ( -0.005 Greenbelt, -0.005 Distribution: -0.004 Near -0.004 USA -0.003 Subject: -0.003 Organization: -0.003 MD -0.003 NNTP- -0.002 From: -0.002 Posting- -0.002 Host: -0.002 Lines: -0.001 4 -0.0 in -0.0 Q)
inputs
-0.002
From:
-0.008
prb@
-0.006
access.
-0.007
digex.
-0.007
net (
-0.009
Pat)
-0.003
Subject:
0.0
Re:
-0.004
Near
0.003
Miss
-0.005
Asteroids (
-0.0
Q)
-0.003
Organization:
-0.007
Express
-0.007
Access
-0.008
Online
-0.006
Communications,
-0.005
Greenbelt,
-0.003
MD
-0.004
USA
-0.002
Lines:
-0.001
4
-0.005
Distribution:
-0.008
sci
-0.003
NNTP-
-0.002
Posting-
-0.002
Host:
-0.009
access.
-0.01
digex.
-0.007
net
0.005
TRry
0.0
the
0.005
SKywatch
-0.015
project
-0.0
in
0.016
Arizona.
-0.012
pat


[1]
outputs
alt.atheism
comp.graphics
rec.motorcycles
sci.space
talk.politics.misc


0.50.30.10.70.90.1504660.150466base value0.7853490.785349falt.atheism(inputs)0.081 alexia.lis.uiuc. 0.072 Debate cobb@ 0.051 alexia. 0.051 lis. 0.046 edu Nobody can explain 0.038 necessarily a religious term, 0.036 that God 0.035 his comments. He stated 0.031 is not 0.028 but 0.026 creationism, I've started 0.026 cobb@ 0.023 examples he gave were quarks 0.022 point -- I can quote 0.022 proven scientific fact. I 0.021 give explanation for events 0.021 think I got his 0.02 other scientific terms that 0.017 Mike Cobb) 0.017 Cobb "...and I won' 0.017 theories, without being a 0.016 Urbana -Bill Clinton 3rd 0.015 or 0.015 dabbling into a book called Christianity and 0.015 From: 0.015 the section if I'm being vague. The 0.013 could be used as 0.012 question that I had come from one of 0.01 t raise taxes on 0.009 the Nature of Science 0.004 uiuc.edu ( 0.003 Champaign- 0.003 are not measurable in 0.003 Subject: Science 0.002 and 0.001 programs." 0.001 my 0.001 for -0.029 measured, tested, etc.? -0.023 everything to anybody. G.K.Chesterton -0.016 per various -0.014 MAC -- **************************************************************** Michael -0.014 University of -0.013 19 As -0.013 A. -0.013 theories Organization: -0.01 Illinois class to pay -0.01 Illinois at -0.009 and continental plates. Are there -0.009 Urbana Lines: -0.007 threads on science -0.007 and -0.006 and of themselves, or can everything be quantified, -0.004 by JP Moreland. A -0.002 explanations of science or -0.002 theories that -0.001 parts of -0.001 the middle University of
inputs
0.015
From:
0.026
cobb@
0.051
alexia.
0.051
lis.
0.004 / 2
uiuc.edu (
0.017 / 2
Mike Cobb)
0.003 / 2
Subject: Science
0.002
and
-0.013 / 2
theories Organization:
-0.014 / 2
University of
-0.01 / 2
Illinois at
-0.009 / 2
Urbana Lines:
-0.013 / 2
19 As
-0.016 / 2
per various
-0.007 / 3
threads on science
-0.007
and
0.026 / 4
creationism, I've started
0.015 / 7
dabbling into a book called Christianity and
0.009 / 4
the Nature of Science
-0.004 / 4
by JP Moreland. A
0.012 / 8
question that I had come from one of
0.035 / 4
his comments. He stated
0.036 / 2
that God
0.031 / 2
is not
0.038 / 4
necessarily a religious term,
0.028
but
0.013 / 4
could be used as
0.02 / 4
other scientific terms that
0.021 / 4
give explanation for events
0.015
or
0.017 / 4
theories, without being a
0.022 / 4
proven scientific fact. I
0.021 / 4
think I got his
0.022 / 4
point -- I can quote
0.015 / 8
the section if I'm being vague. The
0.023 / 5
examples he gave were quarks
-0.009 / 5
and continental plates. Are there
-0.002 / 4
explanations of science or
-0.001 / 2
parts of
-0.002 / 2
theories that
0.003 / 4
are not measurable in
-0.006 / 8
and of themselves, or can everything be quantified,
-0.029 / 3
measured, tested, etc.?
-0.014 / 2
MAC -- **************************************************************** Michael
-0.013
A.
0.017 / 4
Cobb "...and I won'
0.01 / 4
t raise taxes on
-0.001 / 4
the middle University of
-0.01 / 4
Illinois class to pay
0.001
for
0.001
my
0.001
programs."
0.003
Champaign-
0.016 / 4
Urbana -Bill Clinton 3rd
0.072 / 2
Debate cobb@
0.081 / 3
alexia.lis.uiuc.
0.046 / 4
edu Nobody can explain
-0.023 / 6
everything to anybody. G.K.Chesterton

Bar Plots

Bar Plot 1

# Mean token attribution (averaged over both samples) toward sample 0's
# predicted class; the output axis is indexed by class *name* here.
bar_values = partition_explainer.shap_values[:,:, selected_categories[preds[0]]].mean(axis=0)
# Descending-SHAP ordering for the bars (flip of the ascending argsort).
order = partition_explainer.shap.Explanation.argsort.flip
partition_explainer.bar_plot(bar_values, max_display=15, order=order)
Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
../../_images/ce957a0384118349a277476203826720f1f55491d6ad43333a4b1b8b0e7f8174.png

Bar Plot 2

# Sample 0's own token attributions toward its predicted class.
bar_values = partition_explainer.shap_values[0,:, selected_categories[preds[0]]]
partition_explainer.bar_plot(bar_values, max_display=15, order=order)
../../_images/475972ccd54a830be5197b08ca8e5bd7fed91116f7fad2f7519e463300ce90ef.png

Bar Plot 3

# Mean token attribution (over both samples) toward sample 1's predicted class.
bar_values = partition_explainer.shap_values[:,:, selected_categories[preds[1]]].mean(axis=0)
partition_explainer.bar_plot(bar_values, max_display=15, order=order)
../../_images/431e476f8c655b29d161cf7ccc6aa9d106c4759580e61884d9e9e6261114bca6.png

Bar Plot 4

# Sample 1's own token attributions toward its predicted class.
bar_values = partition_explainer.shap_values[1,:, selected_categories[preds[1]]]
partition_explainer.bar_plot(bar_values, max_display=15, order=order)
../../_images/3de8f005d1723556f28d1eec13aa6211be37d60db4655b4f4cd718a60446b4d9.png

Waterfall Plots

Waterfall Plot 1

# Waterfall plot: cumulative token contributions for sample 0 toward its
# predicted class, starting from the explainer's base value.
waterfall_values = partition_explainer.shap_values[0][:, selected_categories[preds[0]]]
partition_explainer.waterfall_plot(waterfall_values, max_display=15)
../../_images/b1ef655d68d8a1926ba5c2d68a114bbb5c4ff72d41b9ae117c22c34e653ffcd5.png

Waterfall Plot 2

# Waterfall plot for sample 1 toward its predicted class.
waterfall_values = partition_explainer.shap_values[1][:, selected_categories[preds[1]]]
partition_explainer.waterfall_plot(waterfall_values, max_display=15)
../../_images/132c34e81433263f83b36e272276b8c0433a41954c8b5b9e2df14d45e089d09d.png

Force Plot

# Force plot for sample 0; note the output axis is indexed by *integer* class
# id here, unlike the name-indexed bar/waterfall cells above.
force_base_values = partition_explainer.shap_values.base_values[0][preds[0]]
force_values = partition_explainer.shap_values[0][:, preds[0]].values
# tokens[:-1] drops the trailing empty string that re.split leaves after the
# final non-word character of the sample text.
partition_explainer.force_plot(force_base_values, force_values,
                               feature_names=tokens[:-1], out_names=selected_categories[preds[0]])
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.